Tutorial 4: Statistical inference on representational geometries#

Week 1, Day 3: Comparing Artificial And Biological Networks

By Neuromatch Academy

Content creators: Veronica Bossio, Eivinas Butkus, Jasper van den Bosch

Content reviewers: Names & Surnames

Production editors: Names & Surnames


Acknowledgments: [ACKNOWLEDGMENT_INFORMATION]


Tutorial Objectives#

*Estimated timing of tutorial: 40 min*

By the end of this tutorial, participants will be able to:

  1. Understand Representational Similarity Analysis (RSA), including its theoretical foundations, practical applications, and its significance for machine learning and cognitive neuroscience.

  2. Extract neural network activations, and understand the structure of neural networks, the role of activations in interpreting network decisions, and practical techniques for accessing them.

  3. Apply frequentist model comparison: the principles underlying these methods, their applications, and the distinctions between frequentist and Bayesian approaches to model evaluation.

  4. Identify the three main sources of estimation error in statistical inference (measurement noise, stimulus sampling, and subject sampling), and explain how these errors motivate model-comparative frequentist inference via the 2-factor bootstrap, which resamples subjects and stimuli together to account for both sampling errors during model comparison.
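The idea behind the 2-factor bootstrap can be sketched in a few lines. This is a minimal illustration on simulated RDMs, not the rsatoolbox routines used later in the tutorial (the array shapes and the helper name are illustrative assumptions): resampling both subjects and stimuli with replacement lets the spread of model evaluations across bootstrap samples reflect subject-sampling and stimulus-sampling error at once.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: one RDM per subject (symmetric, zero diagonal).
n_subjects, n_conditions = 8, 10
rdms = rng.random((n_subjects, n_conditions, n_conditions))
rdms = (rdms + rdms.transpose(0, 2, 1)) / 2
rdms[:, np.arange(n_conditions), np.arange(n_conditions)] = 0

def two_factor_bootstrap_sample(rdms, rng):
    """Resample subjects AND conditions with replacement."""
    n_subj, n_cond, _ = rdms.shape
    subj_idx = rng.integers(0, n_subj, n_subj)    # factor 1: subject sampling
    cond_idx = rng.integers(0, n_cond, n_cond)    # factor 2: stimulus sampling
    sample = rdms[subj_idx]                       # resampled subjects
    sample = sample[:, cond_idx][:, :, cond_idx]  # resampled conditions
    return sample

boot = two_factor_bootstrap_sample(rdms, rng)
print(boot.shape)  # the resampled stack has the same shape as the original
```

Note that when a condition is drawn twice, the corresponding off-diagonal entries of the resampled RDM are zeros; rsatoolbox's bootstrap routines account for such repeats when evaluating models.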


Setup#

Install dependencies#

# @title Install dependencies
# @markdown

!pip install numpy pandas torch torchvision matplotlib ipython Pillow rsatoolbox plotly networkx requests

Import dependencies#

Figure settings#

# @title Figure settings
# @markdown
import logging

logging.getLogger('matplotlib.font_manager').disabled = True

%matplotlib inline
%config InlineBackend.figure_format = 'retina'  # perform high-definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

Plotting functions#

# @title Plotting functions
# @markdown

def traces_bar_and_scatter(eval_result, models, bar_color='blue'):

    evaluations = eval_result.evaluations.squeeze()
    subject_names = [f'Subject {i+1}' for i in range(evaluations.shape[1])]
    model_names = [model.name for model in models]
    df_evaluations = pd.DataFrame(data=evaluations, index=model_names, columns=subject_names)
    means = df_evaluations.mean(axis=1)
    sem = df_evaluations.sem(axis=1)

    bars_trace = go.Bar(
        x=model_names,
        y=means,
        showlegend=False,
        marker_color=bar_color
    )

    scatter_traces = []
    for subject in subject_names:
        scatter_traces.append(go.Scatter(
            x=df_evaluations.index,
            y=df_evaluations[subject],
            mode='markers',
            marker=dict(size=5,
                        color='white',
                        line=dict(width=1)),
            showlegend=False
        ))
    blank_trace = go.Scatter(
        x=[None],  # This ensures the trace doesn't actually plot data
        y=[None],
        mode='markers',
        marker=dict(size=5, color='white', line=dict(width=1)),
        name='Each dot represents <br> a subject'
        )
    return bars_trace, scatter_traces, blank_trace

def plot_bars_and_scatter_from_trace(bars_trace, scatter_traces, blank_trace):

    fig = go.Figure()
    fig.add_trace(bars_trace)
    for trace in scatter_traces:
        fig.add_trace(trace)
    fig.add_trace(blank_trace)
    fig.update_layout(
        title="",
        xaxis_title="Model",
        yaxis_title="Cosine Similarity to Data RDMs",
        legend_title="",
        width=700,
        height=500,
        template="simple_white"
    )
    return fig

def convert_result_to_list_of_dicts(result):
    means = result.get_means()
    sems = result.get_sem()
    p_zero = result.test_zero()
    p_noise = result.test_noise()
    model_names = [model.name for model in result.models]

    results_list = []
    for i, model_name in enumerate(model_names):
        result_dict = {
            "Model": model_name,
            "Eval±SEM": f"{means[i]:.3f} ± {sems[i]:.3f}",
            "p (against 0)": "< 0.001" if p_zero[i] < 0.001 else f"{p_zero[i]:.3f}",
            "p (against NC)": "< 0.001" if p_noise[i] < 0.001 else f"{p_noise[i]:.3f}"
        }
        results_list.append(result_dict)

    return results_list

def print_results_table(table_trace):

    fig = go.Figure()
    fig.add_trace(table_trace)

    return fig

def get_trace_for_table(eval_result):

    results_list = convert_result_to_list_of_dicts(eval_result)

    table_trace = go.Table(
        header=dict(values=["Model", "Eval ± SEM", "p (against 0)", "p (against NC)"]),
        cells=dict(
            values=[
                [result["Model"] for result in results_list],  # Correctly accesses each model name
                [result["Eval±SEM"] for result in results_list],  # Correctly accesses the combined Eval and SEM value
                [result["p (against 0)"] for result in results_list],  # Accesses p-value against 0
                [result["p (against NC)"] for result in results_list]  # Accesses p-value against noise ceiling
            ],
            font=dict(size=12),  # Smaller font size for the cells
            height=27  # Smaller height for the cell rows
            )
    )
    return table_trace

def get_trace_for_noise_ceiling(noise_ceiling):

    noise_lower = np.nanmean(noise_ceiling[0])
    noise_upper = np.nanmean(noise_ceiling[1])
    #model_names = [model.name for model in models]

    noise_rectangle = dict(
            # Rectangle reference to the axes
            type="rect",
            xref="x domain",  # Use 'x domain' to span the whole x-axis
            yref="y",  # Use specific y-values for the height
            x0=0,  # Starting at the first x-axis value
            y0=noise_lower,  # Bottom of the rectangle
            x1=1,  # Ending at the last x-axis value (in normalized domain coordinates)
            y1=noise_upper,  # Top of the rectangle
            fillcolor="rgba(128, 128, 128, 0.4)",  # Light grey fill with some transparency
            line=dict(
                width=0,
                #color="rgba(128, 128, 128, 0.5)",
            )

        )
    return noise_rectangle

def plot_bars_and_scatter_with_table(eval_result, models, method, color='blue', table=True):

    if method == 'cosine':
        method_name = 'Cosine Similarity'
    elif method == 'corr':
        method_name = 'Correlation'
    else:
        method_name = 'Comparison method?'

    if table:
        cols = 2
        subplot_titles=["Model Evaluations", "Model Statistics"]
    else:
        cols = 1
        subplot_titles=["Model Evaluations"]

    fig = make_subplots(rows=1, cols=cols,
                        #column_widths=[0.4, 0.6],
                        subplot_titles=subplot_titles,
                        #specs=[[{"type": "bar"}, {"type": "table"}]]

                            )

    bars_trace, scatter_traces, blank_trace = traces_bar_and_scatter(eval_result, models, bar_color=color)

    fig.add_trace(bars_trace, row=1, col=1)

    for trace in scatter_traces:
        fig.add_trace(trace, row=1, col=1)

    if table:
        table_trace = get_trace_for_table(eval_result)
        fig.add_trace(table_trace, row=1, col=2)

    width = 600*cols

    fig.update_layout(
        yaxis_title=f"RDM prediction accuracy <br> (across subject mean of {method_name})",
        #legend_title="",
        width=width,
        height=600,
        template="plotly_white"
    )
    #fig.add_trace(blank_trace, row=1, col=1)


    return fig

def add_noise_ceiling_to_plot(fig, noise_ceiling):

        rectangle = get_trace_for_noise_ceiling(noise_ceiling)
        fig.add_shape(rectangle, row=1, col=1)
        return fig


def bar_bootstrap_interactive(human_rdms, models_to_compare, method):

    color = 'orange'

    button = widgets.Button(
    description="New Bootstrap Sample",
    layout=widgets.Layout(width='auto', height='auto')  # Adjust width and height as needed
    )

    #button.style.button_color = 'lightblue'  # Change the button color as you like
    button.style.font_weight = 'bold'
    button.layout.width = '300px'  # Make the button wider
    button.layout.height = '48px'  # Increase the height for a squarer appearance
    button.layout.margin = '0 0 0 0'  # Adjust margins as needed
    button.layout.border_radius = '12px'  # Rounded corners for the button

    output = widgets.Output(layout={'border': '1px solid black'})

    def generate_plot(bootstrap=False):
        if bootstrap:
            boot_rdms, idx = bootstrap_sample_rdm(human_rdms, rdm_descriptor='subject')
            result = eval.eval_fixed(models_to_compare, boot_rdms, method=method)
        else:
            result = eval.eval_fixed(models_to_compare, human_rdms, method=method)

        with output:
            clear_output(wait=True)  # Make sure to clear previous output first

            fig = plot_bars_and_scatter_with_table(result, models_to_compare, method, color)
            fig.update_layout(height=600, width=1150,
                              title=dict(text="Performance of model layers for a random bootstrap sample of subjects",
                                         x=0.5, y=0.95,
                                         font=dict(size=20)))
            fig.show()  # Display the figure within the `with` context


    # Arrange the button above the output area
    vbox_layout = widgets.Layout(
        display='flex',
        flex_flow='column',
        align_items='stretch',
        width='100%',
    )

    button.on_click(lambda b: generate_plot(bootstrap=True))  # Generate plot on button click
    vbox = widgets.VBox([button, output], layout=vbox_layout)

    # Display everything
    display(vbox)

    generate_plot(bootstrap=False)

def show_rdm_plotly(rdms, pattern_descriptor=None, cmap='Greys',
                    rdm_descriptor=None, n_column=None, n_row=None,
                    show_colorbar=False, gridlines=None, figsize=(None, None),
                    vmin=None, vmax=None):
    # Determine the number of matrices
    mats = rdms.get_matrices()
    n_matrices = mats.shape[0]


    # Determine the number of subplots
    if n_row is None or n_column is None:
        # Default layout: all matrices in a single row
        n_row = 1
        n_column = n_matrices

    subplot_size = 150
    fig_width = n_column * subplot_size
    fig_height = n_row * subplot_size
    subplot_titles = [f'{rdm_descriptor } {rdms.rdm_descriptors[rdm_descriptor][i]}' for i in range(n_matrices)] if rdm_descriptor else None
    # Create subplots
    fig = make_subplots(rows=n_row, cols=n_column,
                        subplot_titles=subplot_titles,
                        shared_xaxes=True, shared_yaxes=True,
                        horizontal_spacing=0.02, vertical_spacing=0.1)

    # Iterate over RDMs and add them as heatmaps
    for index in range(n_matrices):
        row, col = divmod(index, n_column)
        fig.add_trace(
            go.Heatmap(z=mats[index],
                       colorscale=cmap,
                       showscale=show_colorbar,
                       zmin=vmin, zmax=vmax),
            row=row+1, col=col+1
        )

    fig.update_layout(height=290, width=fig_width)
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)


    #fig.show()
    return fig

def show_rdm_plotly_interactive_bootstrap_patterns(rdms, pattern_descriptor=None, cmap='Greys',
                    rdm_descriptor=None, n_column=None, n_row=None,
                    show_colorbar=False, gridlines=None, figsize=(None, None),
                    vmin=None, vmax=None):


    button = widgets.Button(
    description="New Bootstrap Sample",
    layout=widgets.Layout(width='auto', height='auto')  # Adjust width and height as needed
    )

    #button.style.button_color = 'lightblue'  # Change the button color as you like
    button.style.font_weight = 'bold'
    button.layout.width = '300px'  # Make the button wider
    button.layout.height = '48px'  # Increase the height for a squarer appearance
    button.layout.margin = '0 0 0 0'  # Adjust margins as needed
    button.layout.border_radius = '12px'  # Rounded corners for the button

    #output = widgets.Output(layout={'border': '1px solid black'})
    output = widgets.Output()

    def generate_plot(bootstrap=False):
        if bootstrap:
            im_boot_rdms, pattern_idx = bootstrap_sample_pattern(rdms, pattern_descriptor='index')
        else:
            im_boot_rdms = rdms

        with output:
            clear_output(wait=True)  # Make sure to clear previous output first

            fig = show_rdm_plotly(im_boot_rdms.subset('roi', 'FFA'), rdm_descriptor='subject')
            fig.update_layout(title=dict(text="Bootstrapped sample of patterns",
                                         x=0.5, y=0.95,
                                         font=dict(size=20)))
            fig.show()


    # Now, let's create a VBox to arrange the button above the output
    vbox_layout = widgets.Layout(
        display='flex',
        flex_flow='column',
        align_items='stretch',
        width='100%',
    )

    button.on_click(lambda b: generate_plot(bootstrap=True))  # Generate plot on button click
    vbox = widgets.VBox([button, output], layout=vbox_layout)

    # Display everything
    display(vbox)
    #display(button, output)

    generate_plot(bootstrap=False)

def plot_model_comparison_trans(result, sort=False, colors=None,
                          alpha=0.01, test_pair_comparisons=True,
                          multiple_pair_testing='fdr',
                          test_above_0=True,
                          test_below_noise_ceil=True,
                          error_bars='sem',
                          test_type='t-test'):


    # Prepare and sort data
    evaluations = result.evaluations
    models = result.models
    noise_ceiling = result.noise_ceiling
    method = result.method
    model_var = result.model_var
    diff_var = result.diff_var
    noise_ceil_var = result.noise_ceil_var
    dof = result.dof

    while len(evaluations.shape) > 2:
        evaluations = np.nanmean(evaluations, axis=-1)

    evaluations = evaluations[~np.isnan(evaluations[:, 0])]
    n_bootstraps, n_models = evaluations.shape
    perf = np.mean(evaluations, axis=0)

    noise_ceiling = np.array(noise_ceiling)
    # run tests
    if any([test_pair_comparisons,
            test_above_0, test_below_noise_ceil]):
        p_pairwise, p_zero, p_noise = all_tests(
            evaluations, noise_ceiling, test_type,
            model_var=model_var, diff_var=diff_var,
            noise_ceil_var=noise_ceil_var, dof=dof)

    if error_bars:
        limits = get_errorbars(model_var, evaluations, dof, error_bars,
                               test_type)
        if error_bars.lower() == 'sem':
            limits = limits[0,:]


    fig = make_subplots(rows=2, cols=1,
                        row_heights=[0.3, 0.7],
                        vertical_spacing=0.05,
                        subplot_titles=("Model Evaluations", ''),
                        shared_xaxes=True,
                        )


    n_colors_needed = len(models)
    # Sample n_colors_needed evenly spaced colors from the 'Bluered' color scale
    color_scale = plotly.colors.get_colorscale('Bluered')
    color_indices = np.linspace(0, 1, n_colors_needed)
    sampled_colors = plotly.colors.sample_colorscale(color_scale, color_indices)

    for i, (perf_val, model) in enumerate(zip(perf, models)):
        name = model.name

        fig.add_trace(
            go.Bar(
                x=[name],  # x-axis position
                y=[perf_val],  # Performance value
                error_y=dict(type='data',
                             array=[limits[i]], visible=True, color='black'),  # per-model error bar
                marker_color=sampled_colors[i],  # Cycle through the sampled colors
                name=name
            ),
            row=2, col=1  # Bars go in the lower subplot
        )


    fig.update_layout(width=600, height=700, showlegend=False, template='plotly_white')


    model_significant = p_zero < alpha / n_models
    significant_indices = [i for i, significant in enumerate(model_significant) if significant]
    symbols = {'dewdrops': 'circle', 'icicles': 'diamond-tall'}

    fig.add_trace(
        go.Scatter(
            x=[models[i].name for i in significant_indices],  # X positions of significant models
            y=[0.0005] * len(significant_indices),  # Y positions (at 0 for visualization)
            mode='markers',
            marker=dict(symbol=symbols['dewdrops'],  # Example using 'triangle-up'
                        size=9,
                        color='white'),  # Example using 'triangle-up'
            showlegend=False
        ),
        row=2, col=1
    )

    # Plot noise ceiling
    if noise_ceiling is not None:

        noise_lower = np.nanmean(noise_ceiling[0])
        noise_upper = np.nanmean(noise_ceiling[1])
        model_names = [model.name for model in models]

        fig.add_shape(
                # Rectangle reference to the axes
                type="rect",
                xref="x domain",  # Use 'x domain' to span the whole x-axis
                yref="y",  # Use specific y-values for the height
                x0=0,  # Starting at the first x-axis value
                y0=noise_lower,  # Bottom of the rectangle
                x1=1,  # Ending at the last x-axis value (in normalized domain coordinates)
                y1=noise_upper,  # Top of the rectangle
                fillcolor="rgba(128, 128, 128, 0.5)",  # Light grey fill with some transparency
                line=dict(
                    color='gray',
                ),
                opacity=0.5,
                layer="below",  # Ensure the shape is below the data points
                row=2, col=1  # Specify the subplot where the shape should be added

            )

    # Mark models whose performance is significantly below the noise ceiling's lower bound
    model_below_lower_bound = p_noise < (alpha / n_models)
    significant_indices_below = [i for i, below in enumerate(model_below_lower_bound) if below]
    symbol = 'triangle-down'
    y_positions_below = [noise_lower-0.005] * len(significant_indices_below)  # Adjust based on your visualization needs
    fig.add_trace(
        go.Scatter(
            x=[models[i].name for i in significant_indices_below],  # X positions of significant models
            y= y_positions_below, #* len(significant_indices_below),  # Y positions slightly above noise_lower
            mode='markers',
            marker=dict(symbol=symbol, size=7, color='gray'),  # Customizing marker appearance
            showlegend=False
        ),
        row=2, col=1
    )


    # Pairwise model comparisons
    if test_pair_comparisons:
        if test_type == 'bootstrap':
            model_comp_descr = 'Model comparisons: two-tailed bootstrap, '
        elif test_type == 't-test':
            model_comp_descr = 'Model comparisons: two-tailed t-test, '
        elif test_type == 'ranksum':
            model_comp_descr = 'Model comparisons: two-tailed Wilcoxon-test, '
        n_tests = int((n_models ** 2 - n_models) / 2)
        if multiple_pair_testing is None:
            multiple_pair_testing = 'uncorrected'
        if multiple_pair_testing.lower() == 'bonferroni' or \
           multiple_pair_testing.lower() == 'fwer':
            significant = p_pairwise < (alpha / n_tests)
        elif multiple_pair_testing.lower() == 'fdr':
            ps = batch_to_vectors(np.array([p_pairwise]))[0][0]
            ps = np.sort(ps)
            criterion = alpha * (np.arange(ps.shape[0]) + 1) / ps.shape[0]
            k_ok = ps < criterion
            if np.any(k_ok):
                k_max = np.max(np.where(ps < criterion)[0])
                crit = criterion[k_max]
            else:
                crit = 0
            significant = p_pairwise < crit
        else:
            if 'uncorrected' not in multiple_pair_testing.lower():
                raise ValueError(
                    'plot_model_comparison: Argument ' +
                    'multiple_pair_testing is incorrectly defined as ' +
                    multiple_pair_testing + '.')
            significant = p_pairwise < alpha
        model_comp_descr = _get_model_comp_descr(
            test_type, n_models, multiple_pair_testing, alpha,
            n_bootstraps, result.cv_method, error_bars,
            test_above_0, test_below_noise_ceil)


        # Alternative visualizations (uncomment to use):
        # new_fig_nili = plot_nili_bars_plotly(fig, significant, models, version=1)
        # new_fig_gol = plot_golan_wings_plotly(fig, significant, perf, models)
        new_fig_metro = plot_metroplot_plotly(fig, significant, perf, models, sampled_colors)

        return new_fig_metro

def plot_golan_wings_plotly(original_fig, significant, perf, models):
    # First, create a deep copy of the original figure to preserve its state
    fig = deepcopy(original_fig)

    n_models = len(models)
    model_names = [m.name for m in models]
    # Use the Plotly qualitative color palette
    colors = plotly.colors.qualitative.Plotly

    k = 1  # Vertical position tracker
    marker_size = 8  # Size of the markers
    for i in range(n_models):

        js = np.where(significant[i, :])[0]  # Indices of models significantly different from model i
        if len(js) > 0:
            for j in js:
                # Ensure cycling through the color palette
                color = colors[i % len(colors)]
                fig.add_trace(go.Scatter(x=[model_names[i], model_names[j]],
                                            y=[k, k],
                                        mode='lines',
                                        line=dict(color=color, width=2)
                                        ),
                                        row=1, col=1)
                fig.add_trace(go.Scatter(x=[model_names[i]], y=[k],
                                        mode='markers',
                                        marker=dict(symbol='circle', color=color, size=10,
                                                    line=dict(color=color, width=2))
                                        ),
                                        row=1, col=1)

                if perf[i] > perf[j]:
                    # Rightward feather: model i outperforms model j
                    fig.add_trace(go.Scatter(x=[model_names[j]],
                                            y=[k],
                                            mode='markers',
                                            marker=dict(symbol='triangle-right', color=color, size=marker_size,
                                                        line=dict(color=color, width=2))
                                            ),
                                            row=1, col=1)
                elif perf[i] < perf[j]:
                    # Leftward feather: model j outperforms model i
                    fig.add_trace(go.Scatter(x=[model_names[i], model_names[j]],
                                             y=[k, k],
                                            mode='lines',
                                            line=dict(color=color, width=2)
                                            ),
                                            row=1, col=1)
                    fig.add_trace(go.Scatter(x=[model_names[j]], y=[k],
                                            mode='markers',
                                            marker=dict(symbol='triangle-left', color=color, size=marker_size,
                                                        line=dict(color=color, width=2))
                                            ),
                                            row=1, col=1)
            k += 1  # Increment vertical position after each model's wings are drawn

    # Update y-axis to fit the wings
    fig.update_xaxes(showgrid=False, showticklabels=False, row=1, col=1)
    fig.update_yaxes(showgrid=False, showticklabels=False, row=1, col=1)

    return fig


def plot_metroplot_plotly(original_fig, significant, perf, models, sampled_colors):
    # First, create a deep copy of the original figure to preserve its state
    fig = deepcopy(original_fig)

    model_names = [m.name for m in models]

    for i, (model, color) in enumerate(zip(model_names, sampled_colors)):
        # Indices of models that performed worse than model i
        j_worse = np.where(perf[i] > perf)[0]

        worse_models = [model_names[j] for j in j_worse]  # Names of the worse-performing models
        metropoints = worse_models + [model]  # Model names to plot along the axis
        # Open (white) markers for worse models, a filled marker for the model itself
        marker_colors = ['white' if point != model else color for point in metropoints]

        fig.add_trace(go.Scatter(
                y=np.repeat(model, len(metropoints)),
                x=metropoints,
                mode='lines+markers',
                marker=dict(
                    color=marker_colors,
                    symbol='circle',
                    size=10,
                    line=dict(width=2, color=color)
                ),
                line=dict(width=2, color=color),
                showlegend=False),
                row=1, col=1,
            )

    # Hide grid lines and tick labels on the model-comparison panel
    fig.update_xaxes(showgrid=False, showticklabels=False, row=1, col=1)
    fig.update_yaxes(showgrid=False, showticklabels=False, row=1, col=1)

    return fig

def plot_nili_bars_plotly(original_fig, significant, models, version=1):

    fig = deepcopy(original_fig)

    k = 1  # Vertical position tracker
    ns_col = 'rgba(128, 128, 128, 0.5)'  # Non-significant comparison color
    w = 0.2  # Width for nonsignificant comparison tweaks
    model_names = [m.name for m in models]

    for i in range(significant.shape[0]):
        drawn1 = False
        for j in range(i + 1, significant.shape[0]):
            if version == 1 and significant[i, j]:
                # Draw a line for significant differences
                fig.add_shape(type="line",
                              x0=i, y0=k, x1=j, y1=k,
                              line=dict(color="black", width=2),
                              xref="x1", yref="y1",
                              row=1, col=1)
                k += 1
                drawn1 = True
            elif version == 2 and not significant[i, j]:
                # Draw a line for non-significant differences
                fig.add_shape(type="line",
                              x0=i, y0=k, x1=j, y1=k,
                              line=dict(color=ns_col, width=2),
                              xref="x1", yref="y1",
                              row=1, col=1)
                # Additional visual tweaks for non-significant comparisons
                fig.add_annotation(x=(i+j)/2, y=k, text="n.s.",
                                   showarrow=False,
                                   font=dict(size=8, color=ns_col),
                                   xref="x1", yref="y1",
                                   row=1, col=1)
                k += 1
                drawn1 = True

        if drawn1:
            k += 1  # Increase vertical position after each row of comparisons

    fig.update_xaxes(showgrid=False, showticklabels=False, row=1, col=1)
    fig.update_yaxes(showgrid=False, showticklabels=False, row=1, col=1)

    fig.update_layout(height=700)  # Adjust as necessary
    return fig


def _get_model_comp_descr(test_type, n_models, multiple_pair_testing, alpha,
                          n_bootstraps, cv_method, error_bars,
                          test_above_0, test_below_noise_ceil):
    """constructs the statistics description from the parts

    Args:
        test_type : String
        n_models : integer
        multiple_pair_testing : String
        alpha : float
        n_bootstraps : integer
        cv_method : String
        error_bars : String
        test_above_0 : Bool
        test_below_noise_ceil : Bool

    Returns:
        model_comp_descr : String, description of the statistical tests performed

    """
    if test_type == 'bootstrap':
        model_comp_descr = 'Model comparisons: two-tailed bootstrap, '
    elif test_type == 't-test':
        model_comp_descr = 'Model comparisons: two-tailed t-test, '
    elif test_type == 'ranksum':
        model_comp_descr = 'Model comparisons: two-tailed Wilcoxon-test, '
    n_tests = int((n_models ** 2 - n_models) / 2)
    if multiple_pair_testing is None:
        multiple_pair_testing = 'uncorrected'
    if multiple_pair_testing.lower() == 'bonferroni' or \
       multiple_pair_testing.lower() == 'fwer':
        model_comp_descr = (model_comp_descr
                            + 'p < {:<.5g}'.format(alpha)
                            + ', Bonferroni-corrected for '
                            + str(n_tests)
                            + ' model-pair comparisons')
    elif multiple_pair_testing.lower() == 'fdr':
        model_comp_descr = (model_comp_descr +
                            'FDR q < {:<.5g}'.format(alpha) +
                            ' (' + str(n_tests) +
                            ' model-pair comparisons)')
    else:
        if 'uncorrected' not in multiple_pair_testing.lower():
            raise ValueError(
                'plot_model_comparison: Argument ' +
                'multiple_pair_testing is incorrectly defined as ' +
                multiple_pair_testing + '.')
        model_comp_descr = (model_comp_descr +
                            'p < {:<.5g}'.format(alpha) +
                            ', uncorrected (' + str(n_tests) +
                            ' model-pair comparisons)')
    if cv_method in ['bootstrap_rdm', 'bootstrap_pattern',
                     'bootstrap_crossval']:
        model_comp_descr = model_comp_descr + \
            '\nInference by bootstrap resampling ' + \
            '({:<,.0f}'.format(n_bootstraps) + ' bootstrap samples) of '
    if cv_method == 'bootstrap_rdm':
        model_comp_descr = model_comp_descr + 'subjects. '
    elif cv_method == 'bootstrap_pattern':
        model_comp_descr = model_comp_descr + 'experimental conditions. '
    elif cv_method in ['bootstrap', 'bootstrap_crossval']:
        model_comp_descr = model_comp_descr + \
            'subjects and experimental conditions. '
    if error_bars[0:2].lower() == 'ci':
        model_comp_descr = model_comp_descr + 'Error bars indicate the'
        if len(error_bars) == 2:
            CI_percent = 95.0
        else:
            CI_percent = float(error_bars[2:])
        model_comp_descr = (model_comp_descr + ' ' +
                            str(CI_percent) + '% confidence interval.')
    elif error_bars.lower() == 'sem':
        model_comp_descr = (
            model_comp_descr +
            'Error bars indicate the standard error of the mean.')
    elif error_bars.lower() == 'dots':
        model_comp_descr = (model_comp_descr +
                            'Dots represent the individual model evaluations.')
    if test_above_0 or test_below_noise_ceil:
        model_comp_descr = (
            model_comp_descr +
            '\nOne-sided comparisons of each model performance ')
    if test_above_0:
        model_comp_descr = model_comp_descr + 'against 0 '
    if test_above_0 and test_below_noise_ceil:
        model_comp_descr = model_comp_descr + 'and '
    if test_below_noise_ceil:
        model_comp_descr = (
            model_comp_descr +
            'against the lower-bound estimate of the noise ceiling ')
    if test_above_0 or test_below_noise_ceil:
        model_comp_descr = (model_comp_descr +
                            'are Bonferroni-corrected for ' +
                            str(n_models) + ' models.')
    return model_comp_descr
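The `'fdr'` branch of the pairwise-comparison code above implements the Benjamini-Hochberg procedure for controlling the false discovery rate. Here is a minimal standalone sketch of that logic (an illustration only, not part of the toolbox):

```python
import numpy as np

def benjamini_hochberg(p_values, alpha=0.05):
    """Boolean mask of p-values deemed significant under FDR control.

    Mirrors the 'fdr' branch above: sort the p-values, compare each to its
    rank-scaled criterion, and use the largest passing rank as the threshold.
    """
    p_values = np.asarray(p_values, dtype=float)
    ps = np.sort(p_values)
    n = ps.shape[0]
    criterion = alpha * (np.arange(n) + 1) / n
    passing = ps < criterion
    if np.any(passing):
        crit = criterion[np.max(np.where(passing)[0])]
    else:
        crit = 0.0
    return p_values < crit

# With alpha=0.05 and 4 tests, the rank criteria are .0125, .025, .0375, .05
mask = benjamini_hochberg([0.01, 0.02, 0.30, 0.04], alpha=0.05)
print(mask)  # [ True  True False False]
```

Note that, unlike Bonferroni, the threshold adapts to the observed p-values: here 0.02 passes because it clears its rank-2 criterion, even though it would fail a Bonferroni cut of 0.0125.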

Data retrieval#

Hide code cell source
# @title Data retrieval
# @markdown

def download_file(fname, url, expected_md5):
    """
    Downloads a file from the given URL and saves it locally.
    """
    if not os.path.isfile(fname):
        try:
            r = requests.get(url)
        except requests.ConnectionError:
            print("!!! Failed to download data !!!")
            return
        if r.status_code != requests.codes.ok:
            print("!!! Failed to download data !!!")
            return
        if hashlib.md5(r.content).hexdigest() != expected_md5:
            print("!!! Data download appears corrupted !!!")
            return
        with open(fname, "wb") as fid:
            fid.write(r.content)
        print(f"{fname} has been downloaded successfully.")

def extract_zip(zip_fname):
    """
    Extracts a ZIP file to the current directory.
    """
    with zipfile.ZipFile(zip_fname, 'r') as zip_ref:
        zip_ref.extractall(".")
        print(f"{zip_fname} has been extracted successfully.")

# Details for the zip files to be downloaded and extracted
zip_files = [
    {
        "fname": "fmri_patterns.zip",
        "url": "https://osf.io/7jc3n/download",
        "expected_md5": "c21395575573c62129dc7e9d806f0b5e"
    },
    {
        "fname": "images.zip",
        "url": "https://osf.io/zse8u/download",
        "expected_md5": "ecb0d1a487e90be908ac24c2b0b10fc3"
    }
]

# New addition for other files to be downloaded, specifically non-zip files
image_files = [
    {
        "fname": "NSD.png",
        "url": "https://osf.io/69tj8/download",
        "expected_md5": "a5ff07eb016d837da2624d8e511193ca"
    }
]

# Process zip files: download and extract
for zip_file in zip_files:
    download_file(zip_file["fname"], zip_file["url"], zip_file["expected_md5"])
    extract_zip(zip_file["fname"])

# Process image files: download only
for image_file in image_files:
    download_file(image_file["fname"], image_file["url"], image_file["expected_md5"])
fmri_patterns.zip has been downloaded successfully.
fmri_patterns.zip has been extracted successfully.
images.zip has been downloaded successfully.
images.zip has been extracted successfully.
NSD.png has been downloaded successfully.

Section 1: Data Acquisition#

We will first load fMRI data from the V1 (primary visual cortex) and FFA (fusiform face area) regions, recorded in response to a predefined set of images. These data will let us examine how these brain areas respond to visual stimuli.

# Some constants
SUBJECTS = list(range(1, 9)) # There are 8 subjects
ROIS = ["V1", "FFA"] # Regions of interest in fmri data
IMAGES_DIR = pathlib.Path('images')
FMRI_PATTERNS_DIR = pathlib.Path('fmri_patterns')

Show image#

Hide code cell source
# @title Show image
# @markdown

display(IMG(filename="NSD.png"))

We take fMRI response patterns from the Natural Scenes Dataset (NSD). NSD is a large 7T fMRI dataset of 8 adults viewing more than 73 thousand photos of natural scenes.

We have taken a small subset of 90 images from NSD and have pre-extracted the fMRI data for V1 and the fusiform face area (FFA) from all 8 subjects.

Loading the images#

First, let’s load the 90 image files with the Pillow Image class.

image_paths = sorted(IMAGES_DIR.glob("*.png")) # Find all PNG file paths in the images directory
images = [Image.open(p).convert('RGB') for p in image_paths] # Load them as Image objects
np.array(images[0]).shape # Dimensions of the image array: height x width x channels (RGB)
(425, 425, 3)

Now let's take a look at these images. Notice that the first 45 images contain no faces, while the remaining 45 do. We should therefore expect to see a 2x2 block pattern in the fusiform face area (FFA) representational dissimilarity matrices (RDMs).

fig, ax = plt.subplots(9, 10, figsize=(10, 10))

for i, img in enumerate(images):
    ax[i//10, i%10].imshow(img)
    ax[i//10, i%10].axis('off')
    ax[i//10, i%10].text(0, 0, str(i+1), color='black', fontsize=12)
plt.show()

Loading fMRI patterns from the NSD dataset#

Let’s now load the fMRI patterns from the NSD dataset for these 90 images. We have pre-extracted the patterns, so we just need to load numpy arrays from “.npy” files.

Note that we have 8 subjects and our regions of interest (ROIs) are V1 and FFA.

# Loading fmri data
fmri_patterns = {}
for subject in SUBJECTS:
    fmri_patterns[subject] = {}

    for roi in ROIS:
        fmri_patterns[subject][roi] = {}

        full_data = np.load(FMRI_PATTERNS_DIR / f"subj{subject}_{roi}.npy")
        fmri_patterns[subject][roi] = full_data

# This is how we can index into subject 5 V1 patterns for all the images
fmri_patterns[5]["V1"].shape # Number of images x number of voxels
(90, 2950)

Let's now take a look at the fMRI patterns for two non-face images and two face images.

def plot_fmri_pattern(subject, roi, image_idx, ax):
    pattern = fmri_patterns[subject][roi][image_idx]
    # Pad the voxel vector so it reshapes cleanly into rows of 50 for display
    pattern = np.pad(pattern, (0, (-pattern.shape[0]) % 50))
    pattern = pattern.reshape(-1, 50)

    ax.imshow(pattern, aspect='auto', vmin=-2.2, vmax=2.3, cmap='bwr')
    ax.set_title(f"Subject {subject}, ROI {roi}, Image {image_idx}")

fig, axs = plt.subplots(1, 4, figsize=(15, 3))

subject = 1
roi = "FFA"

# non-face images
plot_fmri_pattern(subject, roi, 1, axs[1])
plot_fmri_pattern(subject, roi, 3, axs[0])

# face images
plot_fmri_pattern(subject, roi, 57, axs[2])
plot_fmri_pattern(subject, roi, 75, axs[3])

plt.show()

Section 2. Get artificial neural network activations#

Now that we have fMRI patterns, we want to explain this data using computational models.

In this tutorial, we will take our models to be layers of AlexNet.

Comparing LeNet architecture to AlexNet. Image from Dive Into Deep Learning book.

AlexNet is a famous convolutional neural network that won the ImageNet challenge in 2012 and started the “deep learning revolution”.

We load a version of AlexNet that is already pre-trained on ImageNet. This step may take a minute; feel free to read ahead.

# Load AlexNet model pretrained on ImageNet
alexnet = torchvision.models.alexnet(weights="IMAGENET1K_V1")
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /home/runner/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:00<00:00, 258MB/s]

To pass images through the model, we need to preprocess them into the same format as the images shown to the model during training.

With AlexNet this includes resizing the images to 224x224 and normalizing their color channels to particular values. We also need to turn them into PyTorch tensors.

# Preprocess NSD images as input to alexnet
# We need to use the exact same preprocessing as was used to train AlexNet
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224,224)), # Resize the images to 224x224 pixels
    torchvision.transforms.ToTensor(), # Convert the images to a PyTorch tensor
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # Normalize the image color channels
])

images_tensor = torch.stack([transform(img) for img in images])
print(images_tensor.shape) # (number of images, channels, height, width)
torch.Size([90, 3, 224, 224])

Let's inspect the AlexNet architecture to select some of its layers as our "models":

print("Architecture of AlexNet:")
print(alexnet)

node_names = get_graph_node_names(alexnet) # this returns a tuple with layer names for the forward pass and the backward pass
print("\nGraph node names (layers) in the forward pass:")
print(node_names[0]) # forward pass layer names
Architecture of AlexNet:
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Graph node names (layers) in the forward pass:
['x', 'features.0', 'features.1', 'features.2', 'features.3', 'features.4', 'features.5', 'features.6', 'features.7', 'features.8', 'features.9', 'features.10', 'features.11', 'features.12', 'avgpool', 'flatten', 'classifier.0', 'classifier.1', 'classifier.2', 'classifier.3', 'classifier.4', 'classifier.5', 'classifier.6']

We extract activations from different layers of AlexNet in response to the same images for which we have NSD fMRI patterns:

# Make hooks in alexnet to extract activations from different layers
return_nodes = {
    "features.2": "conv1",
    "features.5": "conv2",
    "features.7": "conv3",
    "features.9": "conv4",
    "features.12": "conv5",
    "classifier.1": "fc6",
    "classifier.4": "fc7",
    "classifier.6": "fc8"
}
feature_extractor = create_feature_extractor(alexnet, return_nodes=return_nodes)
# Extract activations from alexnet
alexnet_activations = feature_extractor(images_tensor)

# Convert activations to flattened numpy arrays (one row per image)
for layer, activations in alexnet_activations.items():
    act = activations.detach().numpy().reshape(len(images), -1)
    alexnet_activations[layer] = act

alexnet_activations['conv1'].shape # number of images x number of neurons in conv1 layer
(90, 46656)
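As a sanity check, the 46,656 features of "conv1" come from flattening the 64 feature maps left after the first convolution (11x11 kernel, stride 4, padding 2 on a 224x224 input) and the first max-pool (3x3 kernel, stride 2), as shown in the architecture printout above:

```python
# Spatial size after the first convolution: (224 + 2*2 - 11) // 4 + 1 = 55
conv_out = (224 + 2 * 2 - 11) // 4 + 1
# Spatial size after the first max-pool: (55 - 3) // 2 + 1 = 27
pool_out = (conv_out - 3) // 2 + 1
# 64 channels x 27 x 27 spatial positions = 46656 flattened features
n_features = 64 * pool_out * pool_out
print(conv_out, pool_out, n_features)  # 55 27 46656
```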

Section 3. Create representational dissimilarity matrices (RDMs)#

Now that we have fMRI patterns and AlexNet activations, the first step in representational similarity analysis (RSA) is to compute the representational dissimilarity matrices (RDMs). RSA characterizes the representational geometry of a brain region of interest (ROI) by estimating the representational distance between each pair of experimental conditions (e.g., different images).

RDMs represent how dissimilar neural activity patterns or model activations are for each pair of stimuli. In our case, these will be 90x90 image-by-image matrices capturing how dissimilar the fMRI patterns or AlexNet layer activations are for each pair of images.

For instance, we expect that in FFA there will be a large representational distance between the 45 face and the 45 non-face images, so we expect to see a 2x2 block pattern inside the RDM.
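To make the RDM idea concrete before handing the computation to rsatoolbox, here is a toy sketch that computes a Euclidean-distance RDM by hand for four made-up "patterns" (rsatoolbox uses its own distance measures; this is just to illustrate the pairwise structure):

```python
import numpy as np

# Toy "patterns": 4 stimuli x 3 channels. The first two rows resemble each
# other, as do the last two, so the RDM shows a 2x2 block pattern.
patterns = np.array([
    [1.0, 1.0, 0.0],
    [1.0, 0.9, 0.1],
    [0.0, 0.1, 1.0],
    [0.1, 0.0, 1.0],
])

# Pairwise Euclidean distances between all stimulus patterns
diffs = patterns[:, None, :] - patterns[None, :, :]
rdm = np.sqrt((diffs ** 2).sum(axis=-1))

print(rdm.round(2))  # 4x4 symmetric matrix, zeros on the diagonal
```

The small within-block distances and large across-block distances are exactly the structure we expect to see (at a 45x45 scale) in the FFA RDMs below.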

Creating RSA toolbox datasets#

First, let’s wrap our neural and model data in Dataset objects to use the RSA toolbox.

# Create RSA datasets for each subject and ROI
fmri_datasets = {}

for subject in SUBJECTS:
    fmri_datasets[subject] = {}

    for roi in ROIS:
        # Note: this assumes the patterns for every subject/ROI are stored in the same image order
        measurements = fmri_patterns[subject][roi]
        fmri_datasets[subject][roi] = rsa.data.Dataset(
            measurements=measurements,
            descriptors={'subject': subject, 'roi': roi},
            obs_descriptors={'image': np.arange(measurements.shape[0])},
            channel_descriptors={'voxel': np.arange(measurements.shape[1])})
# Create RSA datasets for AlexNet activations
alexnet_datasets = {}

for layer, measurements in alexnet_activations.items():
    alexnet_datasets[layer] = rsa.data.Dataset(
        measurements=measurements,
        descriptors={'layer': layer},
        obs_descriptors={'image': np.arange(measurements.shape[0])},
        channel_descriptors={'channel': np.arange(measurements.shape[1])})

Computing the RDMs#

Let’s compute RDMs for fMRI patterns and AlexNet activations.

# Compute RDMs for each subject and ROI
fmri_rdms = {}
fmri_rdms_list = []

for subject in SUBJECTS:
    fmri_rdms[subject] = {}

    for roi in ROIS:
        fmri_rdms[subject][roi] = rsa.rdm.calc_rdm(fmri_datasets[subject][roi])
        fmri_rdms_list.append(fmri_rdms[subject][roi])

Exercise 2: Use the RSA toolbox to compute the RDMs for the layers of AlexNet#

#################################################
## TODO for students: fill in the missing variables ##
# Fill out function and remove
raise NotImplementedError("Student exercise: fill in the missing variables")
#################################################

# Compute rdms for each layer of AlexNet
alexnet_rdms_dict = {}
for layer, dataset in alexnet_datasets.items():
    alexnet_rdms_dict[layer] = ...

Click for solution

Visualizing human RDMs#

Here we use methods on the rsatoolbox RDMs object to select a subset of the RDMs.

fmri_rdms = rsa.rdm.concat(fmri_rdms_list)
ffa_rdms = fmri_rdms.subset('roi', 'FFA')
show_rdm_plotly(ffa_rdms, rdm_descriptor='subject')

As predicted above, you can see a 2x2 block-like pattern in the FFA fMRI pattern RDMs.

This is because we have 45 non-face images followed by 45 face images.

The lighter regions indicate larger representational distances.

fig = rsa.vis.rdm_plot.show_rdm(ffa_rdms, rdm_descriptor='subject')[0]

Exercise 3: Visualize the RDMs for the fmri patterns from the V1 region#

#################################################
## TODO for students: fill in the missing variables ##
# Fill out function and remove
raise NotImplementedError("Student exercise: fill in the missing variables")
#################################################

fmri_rdms = rsa.rdm.concat(fmri_rdms_list)
v1_rdms = ...
show_rdm_plotly(v1_rdms, rdm_descriptor='subject')

Click for solution

Visualizing AlexNet RDMs#

Let’s look at RDMs for different layers of AlexNet.

alexnet_rdms = rsa.rdm.concat(alexnet_rdms_dict.values())
fig = rsa.vis.rdm_plot.show_rdm(alexnet_rdms, rdm_descriptor='layer')[0]

We see a similar pattern emerge, with face and non-face images clustering apart in the fully connected "fc6", "fc7", and "fc8" layers.

AlexNet seems to “care” about faces too, at least to some extent.

show_rdm_plotly(alexnet_rdms, rdm_descriptor='layer')

Section 4. RSA: model comparison and statistical inference#

In the second step of RSA, each model is evaluated by the accuracy of its prediction of the data RDM. To this end, we will use the RDMs we computed for each model representation.

Each model's prediction of the data RDM is evaluated using an RDM comparator; in this case, we will use the correlation coefficient.
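rsatoolbox applies the comparator for us in the next cell, but the core idea is simple: correlate the off-diagonal entries of the model RDM with those of the data RDM. A toy sketch:

```python
import numpy as np

def rdm_correlation(rdm_a, rdm_b):
    """Pearson correlation between the upper-triangular entries of two RDMs."""
    iu = np.triu_indices(rdm_a.shape[0], k=1)  # indices above the diagonal
    return np.corrcoef(rdm_a[iu], rdm_b[iu])[0, 1]

# Two toy symmetric RDMs for 3 conditions; rdm_b is a rescaled copy of rdm_a
rdm_a = np.array([[0., 1., 2.],
                  [1., 0., 3.],
                  [2., 3., 0.]])
rdm_b = 2.0 * rdm_a

print(rdm_correlation(rdm_a, rdm_b))  # -> 1.0 (up to floating point)
```

Only the upper triangle is used because an RDM is symmetric with a zero diagonal, so the other entries carry no extra information.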

First, let's look at the performance of the different AlexNet layers against the data from all subjects.

# Get the Model objects to use the rsa toolbox for model comparisons
models = []
for layer, rdm in alexnet_rdms_dict.items():
    models.append(rsa.model.ModelFixed(rdm=rdm, name=layer))
roi = 'FFA'
human_rdms = fmri_rdms.subset('roi', roi)
models_to_compare = models

method = 'corr'
result = rsa.inference.evaluate.eval_fixed(models_to_compare, human_rdms, method=method) # Evaluate the models against the fMRI data of all 8 subjects for the FFA ROI

fig = plot_bars_and_scatter_with_table(result, models_to_compare, method, table = False)
fig.update_layout(title=dict(text = f"Performance of AlexNet layers on stimuli <br> in {roi} ROI for original set of subjects",
                              x=0.5, y=0.95,
                              font=dict(size=15)))
add_noise_ceiling_to_plot(fig, result.noise_ceiling)
#fig.update_layout(width=600)
fig.show()

In the plot, each data point represents a model's performance on the representational dissimilarity matrix (RDM) of an individual subject. The observed variability reflects the extent to which our models accurately predict neural activity patterns across different individuals.

Our goal is to determine how these results might generalize to a new cohort of subjects and new sets of stimuli. Since we cannot practically rerun the experiment countless times with fresh subjects and stimuli, we turn to computational simulations.

To achieve this, we will employ bootstrap resampling—a statistical technique that involves resampling our existing dataset with replacement to generate multiple simulated samples. This approach allows us to mimic the process of conducting the experiment anew with different subjects.

First let’s focus on generalization to new subjects. By bootstrapping the subject dataset, for each simulated sample, we can compute the predictive accuracy of our models on the subjects’ RDMs. After running many simulations, we will accumulate a distribution of accuracy estimates. This distribution will enable us to perform statistical inferences about our models’ generalizability to new subjects. (Later, we will address the problem of generalizing to new stimuli as well.)
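The resampling idea can be sketched with plain NumPy. The per-subject accuracies below are hypothetical numbers, just for illustration; in the next cells rsatoolbox performs the equivalent resampling directly over subject RDMs:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical per-subject model accuracies (one value per subject)
accuracies = np.array([0.42, 0.55, 0.38, 0.61, 0.47, 0.52, 0.44, 0.58])

# Bootstrap the subject dimension: resample subjects with replacement and
# recompute the mean accuracy each time, building up a sampling distribution
n_boot = 2000
boot_means = np.array([
    rng.choice(accuracies, size=accuracies.size, replace=True).mean()
    for _ in range(n_boot)
])

# A 95% bootstrap confidence interval for the mean accuracy
ci_low, ci_high = np.percentile(boot_means, [2.5, 97.5])
print(round(ci_low, 3), round(ci_high, 3))
```

The width of this interval tells us how much the mean accuracy would be expected to vary across fresh cohorts of subjects.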

Let’s simulate a ‘new’ sample of subjects by bootstrap resampling, using the RSA toolbox.

boot_rdms, idx = rsa.inference.bootstrap_sample_rdm(human_rdms, rdm_descriptor='subject')

Now we plot the RDMs of the bootstrapped sample.

Each RDM corresponds to one subject (note that, because we sample with replacement, some subjects may appear more than once and others may be missing from the bootstrapped sample).

fig1 = show_rdm_plotly(fmri_rdms.subset('roi', 'FFA'), rdm_descriptor='subject')
fig1.update_layout(title=dict(text = f"Original sample of subjects",
                             x=0.5, y=0.95,
                             font=dict(size=20)))
fig2 = show_rdm_plotly(boot_rdms.subset('roi', 'FFA'), rdm_descriptor='subject')
fig2.update_layout(title=dict(text = f"Bootstrapped sample of subjects",
                             x=0.5, y=0.95,
                             font=dict(size=20)))
fig1.show()
fig2.show()

Let’s see the model performance on different bootstrap-resampled subject sets.

boot_rdms, idx = rsa.inference.bootstrap_sample_rdm(human_rdms, rdm_descriptor='subject')
eval_result = rsa.inference.evaluate.eval_fixed(models_to_compare, boot_rdms, method=method)
fig = plot_bars_and_scatter_with_table(eval_result, models_to_compare, method, color='blue', table = False)
fig

If you run the cell above again, you will see the model performance for a new bootstrap sample of subjects.

Exercise 4: Explore the results for a few simulated new cohorts. How do they change?#

In the third and final step of RSA, we conduct inferential comparisons between models based on their accuracy in predicting the representational dissimilarity matrices (RDMs).

We leverage the variability in the performance estimates observed in the bootstrapped samples to conduct statistical tests. These tests are designed to determine whether the differences in RDM prediction accuracy between models are statistically significant.
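The logic of these tests can be sketched with simulated bootstrap estimates (the numbers below are hypothetical, not toolbox output). Because both models are evaluated on the same bootstrap sample, shared cohort variability cancels in the paired difference, and a two-tailed bootstrap p-value is twice the fraction of samples in which the sign of the difference flips.

```python
import numpy as np

rng = np.random.default_rng(1)
n_boot = 5000

# Shared "cohort" effect: both models are evaluated on the SAME
# bootstrap sample, so their performance estimates co-vary.
cohort = rng.normal(0.0, 0.03, n_boot)
evals_a = 0.50 + cohort + rng.normal(0, 0.01, n_boot)  # model A
evals_b = 0.45 + cohort + rng.normal(0, 0.01, n_boot)  # model B

# Paired comparison: the shared cohort variability cancels in the
# per-sample difference, giving a much tighter distribution.
diff = evals_a - evals_b

# Two-tailed bootstrap p-value: fraction of samples where the sign flips.
p = 2 * min((diff <= 0).mean(), (diff >= 0).mean())
print(f"mean difference: {diff.mean():.3f}")
```

This pairing is why comparing models on the same bootstrap samples is more powerful than comparing their confidence intervals separately.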

boot_result = rsa.inference.eval_bootstrap_rdm(models_to_compare, human_rdms, method=method, N=1000)
plot_model_comparison_trans(boot_result)

Details of the figure above: model comparisons: two-tailed t-test, FDR q < 0.01. Error bars indicate the standard error of the mean. One-sided comparisons of each model performance against 0 and against the lower-bound estimate of the noise ceiling are Bonferroni-corrected for the number of models.
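The corrections mentioned above can be sketched directly (the p-values below are hypothetical; in practice the toolbox performs these tests internally). Bonferroni scales each p-value by the number of tests, while Benjamini-Hochberg FDR compares sorted p-values against an increasing threshold.

```python
import numpy as np

# Hypothetical uncorrected p-values, one per model comparison.
p_values = np.array([0.001, 0.012, 0.030, 0.20])
n_tests = len(p_values)

# Bonferroni: multiply each p-value by the number of tests (cap at 1).
p_bonf = np.minimum(p_values * n_tests, 1.0)

# Benjamini-Hochberg FDR at level q: compare sorted p-values against
# q * rank / n, then reject up to the largest rank that still passes.
q = 0.01
order = np.argsort(p_values)
thresholds = q * np.arange(1, n_tests + 1) / n_tests
passed = p_values[order] <= thresholds
k = int(np.max(np.nonzero(passed)[0]) + 1) if passed.any() else 0
significant = np.zeros(n_tests, dtype=bool)
significant[order[:k]] = True

print(p_bonf)       # Bonferroni-adjusted p-values
print(significant)  # FDR decisions at q = 0.01
```

Bonferroni controls the family-wise error rate and is more conservative; FDR control tolerates a small expected fraction of false discoveries in exchange for more power.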

Generalization to new images#

We have applied a method that enables us to infer how well the models might perform when predicting neural activity patterns for a new cohort of subjects. However, this approach has not yet considered the variability that could arise from introducing a new set of stimuli to the participants.

To extend our inferences to the generalizability of the models with respect to new stimuli, we must replicate the bootstrapping procedure, focusing this time on the stimuli rather than the subjects.

To do this, we will first maintain the original cohort of subjects and apply bootstrapping to resample the stimulus set.

# get the rdms for a bootstrap sample of the images
im_boot_rdms, pattern_idx = rsa.inference.bootstrap_sample_pattern(human_rdms, pattern_descriptor='index')


# plot RDMs
fig = show_rdm_plotly(im_boot_rdms.subset('roi', 'FFA'), rdm_descriptor='subject')
fig

As before, rerunning the cell above will show you the RDMs for a new set of bootstrap resampled stimuli each time.
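Under the hood, resampling stimuli amounts to drawing condition indices with replacement and subsetting both the rows and the columns of each RDM. A minimal sketch with a toy RDM (not the tutorial data):

```python
import numpy as np

rng = np.random.default_rng(2)
n_cond = 6

# Toy symmetric RDM over 6 conditions (zero diagonal).
rdm = rng.random((n_cond, n_cond))
rdm = (rdm + rdm.T) / 2
np.fill_diagonal(rdm, 0)

# Bootstrap the stimuli: draw condition indices WITH replacement...
idx = rng.integers(0, n_cond, size=n_cond)

# ...and subset BOTH rows and columns of the RDM with those indices.
boot_rdm = rdm[np.ix_(idx, idx)]

# Repeated conditions create zero off-diagonal entries (a condition's
# dissimilarity with itself), which evaluation code must handle.
n_repeats = n_cond - len(np.unique(idx))
print(boot_rdm.shape, "repeated conditions:", n_repeats)
```

This is why stimulus bootstrapping is more delicate than subject bootstrapping: repeated conditions produce degenerate dissimilarities that the toolbox excludes when evaluating models.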

Let’s see the inferential model comparisons based on 1000 bootstraps of the image set.

result = rsa.inference.eval_bootstrap_pattern(models_to_compare, human_rdms, theta=None, method='corr', N=1000,
                           pattern_descriptor='index', rdm_descriptor='index',
                           boot_noise_ceil=True)

plot_model_comparison_trans(result)


Section 5. Model Comparison Using Two-factor Bootstrap#

For generalization across both the subject and stimulus populations, we can use a two-factor bootstrap method. For an in-depth discussion of this technique, refer to Schütt et al., 2023.

We can use the RSA toolbox to implement bootstrap resampling of subjects and stimuli simultaneously. It is important to note that a naive 2-factor bootstrap triple-counts the variance contributed by measurement noise; see Schütt et al. for a detailed explanation. Fortunately, the RSA toolbox includes an implementation that corrects for this overestimation.
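One way to see why a correction is needed is to track the variance components (the numbers below are illustrative, not toolbox output). If each single-factor bootstrap picks up the measurement-noise variance once, while the naive 2-factor bootstrap picks it up three times, then a fixed linear combination of the three estimates recovers a total in which every source is counted exactly once:

```python
# Illustrative variance bookkeeping with hypothetical components.
var_subj, var_stim, var_noise = 1.0, 2.0, 0.5

# What each bootstrap scheme approximately estimates:
var_subj_boot = var_subj + var_noise                  # subjects only
var_stim_boot = var_stim + var_noise                  # stimuli only
var_dual_naive = var_subj + var_stim + 3 * var_noise  # naive 2-factor

# Target: each variance source counted exactly once.
var_target = var_subj + var_stim + var_noise

# The corrected combination recovers the target from the three estimates:
var_corrected = 2 * (var_subj_boot + var_stim_boot) - var_dual_naive
print(var_corrected, var_target)  # 3.5 3.5
```

This is the accounting, in sketch form, behind the correction applied by the toolbox's dual-bootstrap evaluation.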

Let’s evaluate the performance of the models with simultaneous bootstrap resampling of the subjects and stimuli.

eval_result = rsa.inference.eval_dual_bootstrap(models_to_compare, fmri_rdms.subset('roi', 'FFA'), method='corr')
print(eval_result)

plot_model_comparison_trans(eval_result)

Further reading#

Representational Similarity Analysis (RSA)#

Kriegeskorte, N., Mur, M., & Bandettini, P. A. (2008). Representational similarity analysis – connecting the branches of systems neuroscience. Frontiers in Systems Neuroscience, 2, 4. link

Two-factor bootstrap#

Heiko H Schütt, Alexander D Kipnis, Jörn Diedrichsen, Nikolaus Kriegeskorte (2023) Statistical inference on representational geometries. eLife 12:e82566. link

RSA Toolbox#

RSA Toolbox documentation link

Natural Scenes Dataset and Algonauts#

Allen, E.J., St-Yves, G., Wu, Y., Breedlove, J.L., Prince, J.S., Dowdle, L.T., Nau, M., Caron, B., Pestilli, F., Charest, I., Hutchinson, J.B., Naselaris, T., Kay, K. A massive 7T fMRI dataset to bridge cognitive neuroscience and artificial intelligence. Nature Neuroscience (2021). link

Gifford AT, Lahner B, Saba-Sadiya S, Vilas MG, Lascelles A, Oliva A, Kay K, Roig G, Cichy RM. 2023. The Algonauts Project 2023 Challenge: How the Human Brain Makes Sense of Natural Scenes. arXiv preprint, arXiv:2301.03198. link

AlexNet#

Krizhevsky, A., Sutskever, I., & Hinton, G. E. (2012). Imagenet classification with deep convolutional neural networks. Advances in neural information processing systems, 25. link